import os
import json

import numpy as np
import pickle
import torch
from torch.utils.data import Dataset
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from scipy.spatial.transform import Rotation as R

from data.create_cliport_programs import utterance2program_bdetr as utterance2program
from data.parser_dataset import LangUtils
from data.cliport_info import SAME_GOAL, DIFF_GOAL

import ipdb
# Debug shortcut: calling ``st()`` drops into the ipdb debugger.
# NOTE(review): this makes ipdb a hard import-time dependency of the module,
# and ``st`` is not used in the code visible here — confirm before removing.
st = ipdb.set_trace

# Frame cap used by ``load_annos_old``: the first and last frame are always
# kept, plus at most MAX_FRAMES - 2 randomly sampled intermediate frames.
MAX_FRAMES = 5
# Target ranges for box centres in ``_pack_boxes_for_ebm``: normalised x/y
# centres are mapped linearly onto [XMIN, XMAX] and [YMIN, YMAX].
XMIN = -0.5
XMAX = 0.5
YMIN = -1
YMAX = 1
# Margin subtracted from the X/Y extents when computing ``self.lens``, the
# half-range that bounds the random translation in ``_augment_pos``.
SIZE = 0.15


class NS_TransporterDataset(Dataset):
    """Dataset class for Transporter.

    Each sample pairs a language ``utterance`` with an ``initial_frame`` and a
    ``goal_frame`` (image paths inside ``<annos_path>/<task_dir>/ebm``), plus
    ground-truth annotations for both frames. ``__init__`` builds the samples
    with :meth:`load_annos_google`. The other ``load_annos*`` methods are
    alternative loaders that ``__init__`` does not call.
    """

    def __init__(self, annos_path, split='train', seed=42, filter=False,
                 repeat=1):
        """Load every annotation for ``split`` into ``self.annos``.

        Args:
            annos_path: root directory containing one sub-directory per task.
            split: a task directory is used only if its name contains this
                substring (e.g. ``'train'`` or ``'test'``).
            seed: stored as ``self.seed``; not otherwise used in this class.
            filter: if True, only directories named in
                ``self.filter_concepts`` are loaded.
            repeat: number of times the annotation list is repeated.
        """
        self.annos_path = annos_path
        self.split = split
        self.seed = seed
        self.lang_utils = LangUtils()
        self.filter = filter

        self.filter_concepts = [
            f'packing-seen-google-objects-group-{self.split}',
            f'packing-unseen-google-objects-group-{self.split}',
        ]
        print(self.filter_concepts)
        # Recomputed per task directory by load_annos; read by _load_gt_boxes.
        self.reverse_gts = False
        self.shapes = ['circle', 'line']

        print(f"Loading Annotations for {split}. Take a breath...")
        annos = self.load_annos_google()
        # Concatenate ``repeat`` copies of the list (the dicts are shared).
        self.annos = [anno for _ in range(repeat) for anno in annos]
        # Half of the free range for box centres; bounds _augment_pos.
        self.lens = np.array([XMAX - XMIN - SIZE, YMAX - YMIN - SIZE]) / 2.0

    # ------------------------------------------------------------------
    # Shared loading helpers
    # ------------------------------------------------------------------
    def _load_pickle(self, dir_name, file_name):
        """Unpickle ``<annos_path>/<dir_name>/ebm/<file_name>`` and close the file.

        Only use this on trusted data: ``pickle`` can execute arbitrary code.
        """
        path = os.path.join(self.annos_path, dir_name, 'ebm', file_name)
        with open(path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _read_image(path):
        """Load an image file into a NumPy array and close the file handle."""
        with Image.open(path) as img:
            return np.asarray(img)

    def _iter_task_dirs(self):
        """Yield ``(dir_name, lang, final_goal, file_names)`` for each usable task dir.

        A directory is used when its name contains ``self.split`` and, if
        ``self.filter`` is set, it is also listed in ``self.filter_concepts``.
        ``lang`` is the unpickled ``ebm/lang.pickle``. ``final_goal`` is True
        when the directory is in ``SAME_GOAL``. ``file_names`` is the sorted
        listing of ``ebm/`` without its last entry.
        """
        for dir_name in os.listdir(self.annos_path):
            if self.filter and dir_name not in self.filter_concepts:
                continue
            if self.split not in dir_name:
                continue
            lang = self._load_pickle(dir_name, 'lang.pickle')
            final_goal = bool(dir_name in SAME_GOAL)
            print("final_goal", final_goal)
            # The last sorted entry is dropped -- presumably 'lang.pickle',
            # which sorts after the frame files (TODO confirm).
            file_names = sorted(os.listdir(
                os.path.join(self.annos_path, dir_name, 'ebm')))[:-1]
            yield dir_name, lang, final_goal, file_names

    @staticmethod
    def _frame_pairs(lang, final_goal):
        """Yield ``(utterance_idx, initial_idx, goal_idx)`` for one task directory.

        A frame whose utterance starts with ``'done'`` ends a task. Every frame
        before such an end frame is paired with the next frame, or with the
        end frame itself when ``final_goal`` is True. The end frame is then
        paired with itself and labelled with the preceding utterance.
        """
        task_ends = np.where([utt.startswith('done') for utt in lang])[0]
        frame_index = 0
        for t_num in task_ends:
            while frame_index != t_num:
                goal_index = t_num if final_goal else frame_index + 1
                yield frame_index, frame_index, goal_index
                frame_index += 1
            # End frame: initial == goal. If a task ends at frame 0, index -1
            # wraps to the last utterance (kept from the original code).
            yield frame_index - 1, frame_index, frame_index
            frame_index += 1

    # ------------------------------------------------------------------
    # Annotation loaders
    # ------------------------------------------------------------------
    def load_annos_no_gt(self):
        """Build ``utterance`` / ``initial_frame`` / ``goal_frame`` samples without boxes.

        NOTE(review): only the first sample is returned (``annos[:1]``). This
        looks like a debugging leftover; it is kept unchanged pending
        confirmation.
        """
        annos = []
        for dir_name, lang, final_goal, file_names in self._iter_task_dirs():
            ebm_dir = os.path.join(self.annos_path, dir_name, 'ebm')
            for utt_idx, init_idx, goal_idx in self._frame_pairs(lang, final_goal):
                annos.append({
                    'utterance': lang[utt_idx],
                    'initial_frame': os.path.join(ebm_dir, file_names[init_idx]),
                    'goal_frame': os.path.join(ebm_dir, file_names[goal_idx]),
                })
        return annos[:1]

    def load_annos(self):
        """Build samples with rotated-box ground truths from ``gt_bbox.pickle``.

        Besides the frame paths, each sample has ``initial_frame_boxes``,
        ``goal_frame_boxes``, ``initial_ebm_boxes`` and ``goal_ebm_boxes``
        (see ``_load_gt_boxes``).
        """
        annos = []
        for dir_name, lang, final_goal, file_names in self._iter_task_dirs():
            # Decided per directory. The flag used to be sticky: once one shape
            # directory was seen, every later directory was reversed too.
            self.reverse_gts = any(shape in dir_name for shape in self.shapes)
            print(f"Reverse Ground truths: {self.reverse_gts}")
            gt_detections = self._load_pickle(dir_name, 'gt_bbox.pickle')
            ebm_dir = os.path.join(self.annos_path, dir_name, 'ebm')
            for utt_idx, init_idx, goal_idx in self._frame_pairs(lang, final_goal):
                init_boxes, init_ebm = self._load_gt_boxes(
                    gt_detections, file_names[init_idx])
                if goal_idx == init_idx:
                    goal_boxes, goal_ebm = init_boxes, init_ebm
                else:
                    goal_boxes, goal_ebm = self._load_gt_boxes(
                        gt_detections, file_names[goal_idx])
                annos.append({
                    'utterance': lang[utt_idx],
                    'initial_frame': os.path.join(ebm_dir, file_names[init_idx]),
                    'goal_frame': os.path.join(ebm_dir, file_names[goal_idx]),
                    'initial_frame_boxes': init_boxes,
                    'goal_frame_boxes': goal_boxes,
                    'initial_ebm_boxes': init_ebm,
                    'goal_ebm_boxes': goal_ebm,
                })
        return annos

    def load_annos_google(self):
        """Build samples with boxes and classes read from ``gt_bbox.pickle``.

        Each sample dict has ``utterance``, ``initial_frame``, ``goal_frame``,
        ``initial_frame_boxes``, ``goal_frame_boxes``,
        ``initial_frame_classes`` and ``goal_frame_classes`` (see
        ``_load_gt_boxes_google``). This is the loader used by ``__init__``.
        """
        annos = []
        for dir_name, lang, final_goal, file_names in self._iter_task_dirs():
            gt_detections = self._load_pickle(dir_name, 'gt_bbox.pickle')
            ebm_dir = os.path.join(self.annos_path, dir_name, 'ebm')
            for utt_idx, init_idx, goal_idx in self._frame_pairs(lang, final_goal):
                init_boxes, init_classes = self._load_gt_boxes_google(
                    gt_detections, file_names[init_idx])
                if goal_idx == init_idx:
                    goal_boxes, goal_classes = init_boxes, init_classes
                else:
                    goal_boxes, goal_classes = self._load_gt_boxes_google(
                        gt_detections, file_names[goal_idx])
                annos.append({
                    'utterance': lang[utt_idx],
                    'initial_frame': os.path.join(ebm_dir, file_names[init_idx]),
                    'goal_frame': os.path.join(ebm_dir, file_names[goal_idx]),
                    'initial_frame_boxes': init_boxes,
                    'goal_frame_boxes': goal_boxes,
                    'initial_frame_classes': init_classes,
                    'goal_frame_classes': goal_classes,
                })
        return annos

    # ------------------------------------------------------------------
    # Ground-truth parsing
    # ------------------------------------------------------------------
    def _load_gt_boxes(self, gt_detections, key):
        """Return axis-aligned boxes for the objects recorded under frame ``key``.

        ``gt_detections[key]`` maps each of ``'fixed'``, ``'rigid'`` and
        ``'deformable'`` to a dict whose values ``det`` hold two corner points
        in ``det[0]`` and a quaternion in ``det[1]``. Each corner stores its
        coordinates in swapped order, so ``x`` is read from index 1 and ``y``
        from index 0. The rectangle is rotated about its centre by the
        quaternion, and only the resulting x/y are kept. The box is the
        rounded integer bounds ``[x1, y1, x2, y2]`` of the rotated corners.

        Returns:
            ``(frame_boxes, ebm_boxes)``. Normally ``frame_boxes`` stacks the
            fixed + rigid + deformable boxes, and ``ebm_boxes`` is
            ``[fixed, rigid]``. When ``self.reverse_gts`` is set,
            ``frame_boxes`` holds only the 'rigid' boxes and ``ebm_boxes`` is
            the 'deformable' boxes twice.
        """
        boxes_by_mode = {}
        for mode in ("fixed", "rigid", "deformable"):
            mode_boxes = []
            for det in gt_detections[key][mode].values():
                x1, y1, x2, y2 = det[0][0][1], det[0][0][0], det[0][1][1], det[0][1][0]
                rotation = R.from_quat(list(det[1]))
                corners = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
                center = corners.mean(0)
                corners_3d = np.zeros((4, 3))
                corners_3d[:, :2] = corners - center
                rotated = rotation.apply(corners_3d)[:, :2] + center
                # ``np.int`` (an alias of ``int``) was removed in NumPy 1.24.
                rotated = np.round(rotated).astype(int)
                mode_boxes.append([
                    np.min(rotated[:, 0]), np.min(rotated[:, 1]),
                    np.max(rotated[:, 0]), np.max(rotated[:, 1]),
                ])
            boxes_by_mode[mode] = mode_boxes

        fixed = boxes_by_mode['fixed']
        rigid = boxes_by_mode['rigid']
        deformable = boxes_by_mode['deformable']
        if not self.reverse_gts:
            frame_boxes = np.array(fixed + rigid + deformable)
            ebm_boxes = [np.array(fixed), np.array(rigid)]
        else:
            # The 'rigid' and 'deformable' lists trade roles in this mode.
            frame_boxes = np.array(rigid)
            ebm_boxes = [np.array(deformable), np.array(deformable)]
        return frame_boxes, ebm_boxes

    def _load_gt_boxes_google(self, gt_detections, key):
        """Return ``(boxes, classes)`` for the frame file ``key``.

        ``key`` is mapped to ``<prefix>.jpg``, where the prefix is the text
        before the first ``'_'``. The entry in ``gt_detections`` is read as an
        alternating sequence ``[class, box, class, box, ...]`` with boxes
        ``[x1, y1, x2, y2]``. Boxes are rounded to int64 and de-duplicated;
        as with ``np.unique``, they come back sorted. The classes are indexed
        with the same selection, so ``classes[i]`` stays paired with
        ``boxes[i]``. ``boxes`` always has shape ``(N, 4)``, including N == 0.
        """
        key = key.split('_')[0] + '.jpg'
        try:
            detections = np.array(gt_detections[key])
        except ValueError:
            # NumPy >= 1.24 refuses to build ragged arrays implicitly; older
            # versions produced exactly this object array (with a warning).
            detections = np.array(gt_detections[key], dtype=object)
        all_classes = detections[0::2]
        raw_boxes = detections[1::2]
        all_boxes = np.array(
            [np.array([box[0], box[1], box[2], box[3]]).round().astype(np.int64)
             for box in raw_boxes],
            dtype=np.int64,
        ).reshape(-1, 4)  # keeps a (0, 4) shape when there are no boxes
        if len(all_boxes):
            # Select the classes with the same indices so they stay aligned
            # with the de-duplicated, sorted boxes.
            all_boxes, keep = np.unique(all_boxes, axis=0, return_index=True)
            all_classes = all_classes[keep]
        return all_boxes, all_classes

    def load_annos_old(self):
        """Legacy loader: one sample per utterance with a stack of image frames.

        For each utterance ``i``, the ``.jpg`` files whose names start with
        the zero-padded index are loaded. The first and last frames are always
        kept, and at most ``MAX_FRAMES - 2`` intermediate frames are sampled
        at random.

        NOTE(review): when ``self.filter`` is set, ``utterances`` is filtered
        but ``images`` / ``task_names`` / ``frame_names`` are not. The final
        ``zip`` can therefore pair mismatched entries -- confirm before
        relying on this loader.
        """
        utterances = []
        images = []
        task_names = []
        frame_names = []
        for dir_name in os.listdir(self.annos_path):
            if self.split not in dir_name:
                continue
            lang = self._load_pickle(dir_name, 'lang.pickle')

            if self.filter:
                for utt in lang:
                    if any(utt.startswith(concept) for concept in self.filter_concepts):
                        utterances.append(utt)
            else:
                utterances += lang

            ebm_dir = os.path.join(self.annos_path, dir_name, 'ebm')
            # The directory listing does not change per utterance; read it once.
            ebm_files = sorted(os.listdir(ebm_dir))
            for i in range(len(lang)):
                prefix = str(i).zfill(6)
                task_names.append(dir_name)
                frame_names.append(prefix)
                image_i = [
                    self._read_image(os.path.join(ebm_dir, file))
                    for file in ebm_files
                    if file.endswith('.jpg') and file.startswith(prefix)
                ]
                mask = None  # only set when frames are sub-sampled
                if len(image_i) > MAX_FRAMES:
                    mask = np.random.choice(
                        np.arange(1, len(image_i) - 1), MAX_FRAMES - 2,
                        replace=False)
                    image_i_ = np.array(image_i)[mask]
                else:
                    image_i_ = np.array(image_i[1:-1])
                if len(image_i_) == 0:
                    print(utterances)
                    print(dir_name)
                    print(i)
                    print(len(image_i))
                    print(mask)
                    print("something bad")
                img_all_frames = np.stack([image_i[0], *image_i_, image_i[-1]], axis=0)
                images.append(img_all_frames)
        annos = [
            {
                'task_name': tname,
                'frame_name': fname,
                'utterance': utt,
                'image_frames': img_f,
            } for utt, img_f, tname, fname in zip(utterances, images, task_names, frame_names)
        ]
        return annos

    # ------------------------------------------------------------------
    # Dataset protocol
    # ------------------------------------------------------------------
    def __getitem__(self, index):
        """Return sample ``index`` with its parsed program, images and ground truths.

        Expects annotations built by ``load_annos_google``, since it reads the
        ``*_frame_classes`` keys. Boxes of both frames are clamped to the
        initial frame's height/width. EBM-box packing and the ``_augment_*``
        helpers are not applied here.
        """
        anno = self.annos[index]
        utterance = anno['utterance']
        program = utterance2program(utterance)
        raw_utterance, program_tree = self.lang_utils.get_program(
            utterance, program
        )
        initial_frame = self._read_image(anno['initial_frame'])
        goal_frame = self._read_image(anno['goal_frame'])

        # Assumes the goal frame has the same size as the initial frame.
        H, W = initial_frame.shape[:2]
        return {
            "raw_utterance": raw_utterance,
            "program_list": program,
            "program_tree": program_tree,
            "initial_frame": initial_frame,
            "goal_frame": goal_frame,
            "initial_frame_path": anno['initial_frame'],
            "goal_frame_path": anno['goal_frame'],
            'initial_frame_boxes': self._clamp_boxes(anno['initial_frame_boxes'], H, W),
            'goal_frame_boxes': self._clamp_boxes(anno['goal_frame_boxes'], H, W),
            'initial_frame_classes': anno['initial_frame_classes'],
            'goal_frame_classes': anno['goal_frame_classes'],
        }

    def __len__(self):
        """Return number of utterances"""
        return len(self.annos)

    # ------------------------------------------------------------------
    # Box utilities / augmentation
    # ------------------------------------------------------------------
    def _pack_boxes_for_ebm(self, boxes, height, width):
        """Convert pixel ``(x1, y1, x2, y2)`` boxes to ``(cx, cy, w, h)`` for the EBM.

        Coordinates are first normalised by ``width`` / ``height``. The
        centres are then mapped linearly onto ``[XMIN, XMAX]`` and
        ``[YMIN, YMAX]``, while ``w`` / ``h`` stay normalised to the image
        size.
        """
        boxes = np.copy(boxes).astype(float)
        boxes[..., (0, 2)] /= float(width)
        boxes[..., (1, 3)] /= float(height)

        # to center-size
        boxes = np.stack((
            (boxes[..., 0] + boxes[..., 2]) * 0.5,
            (boxes[..., 1] + boxes[..., 3]) * 0.5,
            boxes[..., 2] - boxes[..., 0],
            boxes[..., 3] - boxes[..., 1]
        ), -1)

        # scale centres into the workspace ranges
        boxes[..., 0] = boxes[..., 0]*(XMAX - XMIN) + XMIN
        boxes[..., 1] = boxes[..., 1]*(YMAX - YMIN) + YMIN
        return boxes

    @staticmethod
    def _clamp_boxes(frame_boxes, H, W):
        """Clip ``(x1, y1, x2, y2)`` boxes into ``[0, W-1] x [0, H-1]``."""
        frame_boxes = np.clip(frame_boxes, a_min=0, a_max=None)
        frame_boxes = np.minimum(
            frame_boxes, np.array([W-1, H-1, W-1, H-1]).reshape(1, 4))
        return frame_boxes

    @staticmethod
    def _augment_rot(boxes):
        """Rotate ``(N, 2)`` points by a random angle in [-180, 180) degrees."""
        max_degrees = 180
        random_rot = (2 * np.random.rand() - 1) * max_degrees
        boxes = rot_z(boxes, random_rot)
        return boxes

    def _augment_pos(self, boxes):
        """Translate all boxes by one random offset, keeping centres within ``±self.lens``.

        Takes points that reshape to ``(-1, 4)`` boxes and returns ``(-1, 2)``
        points. ``reshape`` may return a view, so the input can be modified in
        place.
        """
        boxes = boxes.reshape(-1, 4)
        centers = 0.5 * (boxes[:, :2] + boxes[:, 2:])
        xmin, ymin = centers.min(0).flatten().tolist()
        xmax, ymax = centers.max(0).flatten().tolist()
        x_translation = np.random.uniform(
            -xmin - self.lens[0], self.lens[0] - xmax, 1
        )
        y_translation = np.random.uniform(
            -ymin - self.lens[1], self.lens[1] - ymax, 1
        )
        boxes[:, (0, 2)] += x_translation[None, :]
        boxes[:, (1, 3)] += y_translation[None, :]
        return boxes.reshape(-1, 2)

    def _augment_scale(self, boxes):
        """Scale all coordinates by a random factor drawn between 0.8 and ``1 / max|boxes|``."""
        imax = np.abs(boxes).max()
        scale = np.random.uniform(0.8, 1 / (imax + 1e-5), 1)
        return boxes * scale


def rot_z(pc, theta):
    """Rotate 2-D points about the origin by ``theta`` degrees.

    Applies the rotation matrix ``[[cos, -sin], [sin, cos]]`` to each point,
    i.e. a counter-clockwise rotation in a standard x/y frame.

    Args:
        pc: points of shape ``(N, 2)`` (a single ``(2,)`` point also works).
        theta: rotation angle in degrees.

    Returns:
        The rotated points, with the same shape as ``pc``.
    """
    radians = theta * np.pi / 180
    cos_t = np.cos(radians)
    sin_t = np.sin(radians)
    rotation = np.array([
        [cos_t, -sin_t],
        [sin_t, cos_t],
    ])
    rotated_columns = np.matmul(rotation, pc.T)
    return rotated_columns.T


def ns_transporter_collate_fn(batch):
    """Collate NS Transporter samples into a single batch dict.

    Image frames are converted with ``torch.from_numpy`` and stacked along a
    new leading batch dimension, so all frames in a batch must share one
    shape. Every other field is gathered into a plain Python list in batch
    order. EBM boxes are not collated.
    """
    def gather(field):
        # One entry per sample, preserving batch order.
        return [sample[field] for sample in batch]

    def stack_frames(field):
        return torch.stack([torch.from_numpy(frame) for frame in gather(field)], dim=0)

    return {
        "raw_utterances": gather("raw_utterance"),
        "program_lists": gather("program_list"),
        "program_trees": gather("program_tree"),
        "initial_frames": stack_frames("initial_frame"),
        "goal_frames": stack_frames("goal_frame"),
        "initial_frame_paths": gather('initial_frame_path'),
        "goal_frame_paths": gather('goal_frame_path'),
        "initial_frame_boxes": gather('initial_frame_boxes'),
        "goal_frame_boxes": gather('goal_frame_boxes'),
        "initial_frame_classes": gather('initial_frame_classes'),
        "goal_frame_classes": gather('goal_frame_classes'),
    }


if __name__ == '__main__':
    # Smoke test: build the filtered test split and fetch every sample once.
    dataset = NS_TransporterDataset(
        annos_path='/projects/""/ns_transporter_data/grounder_dataset_10',
        split='test',
        filter=True,
    )
    for idx in range(len(dataset)):
        print(idx)
        _ = dataset[idx]
